[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852
kevalmorabia97 merged 11 commits into main
Conversation
Note: Reviews paused. This branch is under active development; to avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough
Adds FP8 post-processing converting TensorRT-specific FP8 QDQ nodes to native ONNX QDQ, introduces ONNX Cast utilities for redundant-cast removal and targeted FP16 conversion, and updates Torch ONNX export to apply these utilities with a reordered FP16/quantization pipeline.
Sequence Diagram(s)
sequenceDiagram
participant Exporter as FP8 Exporter
participant Graph as ONNX Graph
participant TRT as TRT_FP8 Nodes
participant Native as Native ONNX Ops
participant Cleaner as Graph Cleaner
Exporter->>Graph: scan for TRT_FP8QuantizeLinear / TRT_FP8DequantizeLinear
Graph->>TRT: identify TRT_FP8QuantizeLinear nodes
Exporter->>Graph: for each TRT_FP8QuantizeLinear -> create `zero_point` const if missing
Exporter->>Graph: replace TRT_FP8QuantizeLinear with QuantizeLinear (set saturate)
Graph->>TRT: identify TRT_FP8DequantizeLinear nodes
Exporter->>Graph: replace TRT_FP8DequantizeLinear with DequantizeLinear
Exporter->>Cleaner: invoke cleanup & topological sort
Cleaner->>Graph: remove unused nodes, fix edges, toposort
Cleaner->>Native: graph now uses native ONNX QDQ nodes
Exporter->>Exporter: export cleaned ONNX model
Note over Exporter,Cleaner: logger.info/debug traces conversions
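The rewrite steps in the diagram can be sketched with a simplified, dictionary-based stand-in for the graph. The real implementation works on onnx-graphsurgeon nodes; the field names here ("op", "inputs", "attrs") mirror that API, but the data structures and helper below are illustrative only:

```python
# Simplified sketch of the TRT_FP8 -> native ONNX QDQ rewrite.
# Nodes are plain dicts standing in for onnx-graphsurgeon nodes.

FP8_ZERO_POINT = b"\x00"  # zero encoded as FLOAT8E4M3FN

def convert_trt_fp8_qdq(nodes: list[dict]) -> list[dict]:
    for node in nodes:
        if node["op"] == "TRT_FP8QuantizeLinear":
            node["op"] = "QuantizeLinear"
            # Native QuantizeLinear takes (x, scale, zero_point);
            # add an FP8 zero point if only (x, scale) are present.
            if len(node["inputs"]) == 2:
                node["inputs"].append(
                    {"name": node["name"] + "_zero_point", "raw": FP8_ZERO_POINT}
                )
            node["attrs"]["saturate"] = 1  # clamp out-of-range values to the FP8 max
        elif node["op"] == "TRT_FP8DequantizeLinear":
            node["op"] = "DequantizeLinear"
    return nodes

nodes = [
    {"op": "TRT_FP8QuantizeLinear", "name": "q0", "inputs": ["x", "scale"], "attrs": {}},
    {"op": "TRT_FP8DequantizeLinear", "name": "dq0", "inputs": ["q0_out", "scale", "zp"], "attrs": {}},
]
converted = convert_trt_fp8_qdq(nodes)
```

After this pass the cleanup/toposort step in the diagram removes any nodes orphaned by the rewrite.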
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 3 passed
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 104-140: The post_process function's docstring mentions updating
GELU nodes to tanh approximation and inserting Cast nodes after Sqrt, but the
implementation in post_process only converts
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to
QuantizeLinear/DequantizeLinear; either remove or revise those docstring lines
to reflect current behavior, or implement the missing steps: locate GELU nodes
in graph.nodes and replace/modify them to the tanh-approx variant, and insert
Cast nodes immediately after Sqrt nodes' outputs; reference post_process,
TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, GELU, and Sqrt when making the
change.
- Around line 119-126: The FP8 zero-point tensor zp_tensor is missing explicit
shape metadata; update the creation of zp_tensor (used to build zero_point and
appended to node.inputs) to set its dims explicitly (e.g., call
zp_tensor.dims.extend([1]) for a 1-element tensor) so it matches other tensors
created in this module (see the FP8 weights tensor creation) and ensures ONNX
runtimes receive shape info.
In `@modelopt/onnx/utils.py`:
- Around line 1314-1349: In change_casts_to_fp16, only modify Cast nodes that
actually cast from FP32: for each Cast node (node.op_type == "Cast") look up the
source tensor name node.input[0] in graph.initializer, graph.input,
graph.value_info or graph.output to get its element_type and only change the
node.attribute "to" from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 if
the source dtype is FLOAT; also avoid changing Casts that are FP16->FP32 and add
a debug log entry when you modify a Cast (include node.name or node.output[0]
and original->new dtypes) to aid debugging.
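The guarded rewrite this prompt describes can be sketched over a toy, dict-based graph. The integer dtype codes match `onnx.TensorProto.FLOAT`/`FLOAT16`, but the graph representation and helper name are stand-ins, not the actual modelopt implementation:

```python
FLOAT, FLOAT16 = 1, 10  # numeric values of onnx.TensorProto.FLOAT / FLOAT16

def change_casts_to_fp16_guarded(nodes, tensor_types, target_ops=("Concat", "Add")):
    # Map each tensor name to the ops that consume it.
    consumers = {}
    for node in nodes:
        for inp in node["inputs"]:
            consumers.setdefault(inp, []).append(node["op"])
    for node in nodes:
        if node["op"] != "Cast" or node["attrs"].get("to") != FLOAT:
            continue
        feeds_target = any(op in target_ops for op in consumers.get(node["output"], []))
        src_type = tensor_types.get(node["inputs"][0])
        # Only rewrite casts whose source really is FP32; an FP16->FP32
        # upcast is intentional and must be left alone.
        if feeds_target and src_type == FLOAT:
            node["attrs"]["to"] = FLOAT16
    return nodes

nodes = [
    {"op": "Cast", "inputs": ["a"], "output": "a_c", "attrs": {"to": FLOAT}},
    {"op": "Cast", "inputs": ["b"], "output": "b_c", "attrs": {"to": FLOAT}},  # fp16->fp32 upcast
    {"op": "Concat", "inputs": ["a_c", "b_c"], "output": "c", "attrs": {}},
]
tensor_types = {"a": FLOAT, "b": FLOAT16}
out = change_casts_to_fp16_guarded(nodes, tensor_types)
```

Only the cast from the FP32 source is rewritten; the intentional upcast from FP16 is preserved, which is exactly the behavior the comment requests.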
🧹 Nitpick comments (1)
modelopt/onnx/utils.py (1)
1218-1261: Consider edge case where first Cast has multiple consumers. The function checks `len(node.outputs[0].outputs) != 1` (line 1231) to ensure the first Cast's output goes to exactly one node. However, this may be overly restrictive. If the first Cast feeds into a duplicate second Cast AND other nodes, you could still remove the duplicate Cast while preserving the connection to other consumers. The current logic skips this optimization opportunity, though the current implementation is safe.
Codecov Report
❌ Patch coverage warning. Additional details and impacted files:
@@ Coverage Diff @@
## main #852 +/- ##
==========================================
- Coverage 70.07% 70.06% -0.02%
==========================================
Files 221 221
Lines 25499 25603 +104
==========================================
+ Hits 17869 17939 +70
- Misses 7630 7664 +34
☔ View full report in Codecov by Sentry.
@ajrasane can you please add the before and after accuracy results in the PR description? I.e. with FP8 custom Q/DQ nodes vs with FP8 native Q/DQ nodes. Thanks!
Let's also add this change to the Changelog file.
    op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
Can you please elaborate the goal/need of this function? Thanks!
This is because after the convert_float_to_float16() function, one of the inputs for these nodes is FP16, while the other is FP32. Hence we run into a compilation issue with TensorRT. To fix this, I manually update them here for these operators.
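The failure mode described here — one input reaching Concat/Add as FP16 while the other stays FP32 after conversion — can be illustrated with a small checker over a toy, dict-based graph (node and tensor names are hypothetical):

```python
FLOAT, FLOAT16 = 1, 10  # numeric values of onnx.TensorProto.FLOAT / FLOAT16

def find_mixed_precision_nodes(nodes, tensor_types, ops=("Concat", "Add")):
    """Return names of nodes whose float inputs disagree on precision,
    the situation that made the TensorRT engine build fail here."""
    mixed = []
    for node in nodes:
        if node["op"] not in ops:
            continue
        dtypes = {tensor_types[t] for t in node["inputs"] if t in tensor_types}
        if len(dtypes) > 1:
            mixed.append(node["name"])
    return mixed

# After convert_float_to_float16(), "x" became FP16 but a blocked-op
# branch still produces FP32 "y":
tensor_types = {"x": FLOAT16, "y": FLOAT}
nodes = [{"op": "Add", "name": "add0", "inputs": ["x", "y"]}]
bad = find_mixed_precision_nodes(nodes, tensor_types)
```

Rewriting the FP32-producing Cast to FP16 (as `change_casts_to_fp16` does) makes both inputs agree and empties this list.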
Got it, thanks for the explanation. Can you please update the docstring to give a bit more details? Thanks!
@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?
I run into an error while building the engine with TensorRT:
[03/12/2026-21:06:15] [E] Error[9]: Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [myelin_graph.h:attachExceptionMsgToGraph:1146] MyelinCheckException: operand.h:456: CHECK(is_tensor()) failed. In compileGraph at optimizer/myelin/codeGenerator.cpp:1421
[03/12/2026-21:06:15] [E] Error[10]: IBuilder::buildSerializedNetworkToStream: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[x_cast_to_fp16...(Unnamed Layer* 1752) [ElementWise]]}. In computeCosts at optimizer/common/tactic/optimizer.cpp:4115)
I also remember that you had previously mentioned that autocast is not supposed to be used after quantization, as it would need a separate design. Hence I removed it from here. Let me know if that is no longer the case.
Accuracy looks good, any idea why perf is slower after this PR? Also, can you please specify which model these numbers are for? Thanks.
        logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")

if removed_count > 0:
    graph.cleanup().toposort()
I recall some issues with toposort. If you see any failures due to it, we can probably omit it; _bypass_cast maintains node sorting. AutoCast's unit testing covers this part well, and indeed, I see there are quite a few failures with this refactor.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
1151-1172: ⚠️ Potential issue | 🟠 Major: This shared cast cleanup can undo output-name preservation.
`_cleanup()` already fixes network output names before this call, but `onnx_utils.remove_redundant_casts()` bypasses output casts by replacing `graph.outputs` with the cast input. For redundant casts on model outputs, that reverts the exported output tensor name to the pre-cast name and can trip `_sanity_check()` or break the public I/O contract.
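One way to avoid that regression, sketched with plain strings rather than real graph objects (the helper name and all tensor names are hypothetical):

```python
def bypass_output_cast(producer_output: str, cast_output: str,
                       graph_outputs: set[str], renames: dict[str, str]) -> str:
    """When removing a redundant Cast that feeds a graph output, keep the
    public output name by renaming the producer tensor instead of exposing
    the pre-cast name to callers."""
    if cast_output in graph_outputs:
        renames[producer_output] = cast_output  # preserve the exported I/O name
        return cast_output
    return producer_output

renames: dict[str, str] = {}
kept = bypass_output_cast("matmul_out", "output_0", {"output_0"}, renames)
```

Here the Cast fed the model output `output_0`, so the producer tensor is renamed to it rather than reverting the exported name to `matmul_out`.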
🧹 Nitpick comments (1)
modelopt/torch/_deploy/utils/torch_onnx.py (1)
62-71: Scope this `onnxconverter_common` workaround to the conversion call. Patching the module at import time changes behavior process-wide, and `suppress(AttributeError)` hides every upstream `AttributeError`, not just the known list/attr bug. A temporary patch around `convert_float_to_float16()` is much safer, and it avoids making this module import brittle if the upstream symbol changes.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 62 - 71, The current import-time monkey-patch of _f16_module.remove_unnecessary_cast_node (using _original_remove_unnecessary_cast_node and _patched_remove_unnecessary_cast_node) is global and hides all AttributeError via suppress(AttributeError); instead, scope the workaround only around the call to convert_float_to_float16(): before calling convert_float_to_float16() save the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal wrapper that only catches the specific list/attribute error, call convert_float_to_float16(), and finally restore the original function in a try/finally so the patch is temporary and does not swallow unrelated AttributeErrors or affect the rest of the process.
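The save/swap/call/restore pattern the prompt describes looks roughly like this. The module object below is a stand-in for the `onnxconverter_common` float16 module; the buggy helper mimics the known upstream failure mode, and the wrapper body is a sketch, not the actual fix:

```python
import types

def _buggy_remove_cast(graph):
    # Mimics the known upstream bug in onnxconverter_common.
    raise AttributeError("'list' object has no attribute 'graph'")

_f16_module = types.SimpleNamespace(remove_unnecessary_cast_node=_buggy_remove_cast)

def convert_with_scoped_patch(graph):
    """Temporarily wrap the helper, run the conversion, then restore."""
    original = _f16_module.remove_unnecessary_cast_node

    def patched(g):
        try:
            return original(g)
        except AttributeError:
            return None  # swallow only the known list/attr bug

    _f16_module.remove_unnecessary_cast_node = patched
    try:
        # In the real code, convert_float_to_float16() would run here and
        # internally call the patched helper.
        return _f16_module.remove_unnecessary_cast_node(graph)
    finally:
        _f16_module.remove_unnecessary_cast_node = original  # always restored

result = convert_with_scoped_patch([])
patch_restored = _f16_module.remove_unnecessary_cast_node is _buggy_remove_cast
```

The try/finally guarantees the patch is removed even if the conversion raises, so no other caller ever sees the wrapped helper.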
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: In FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 576-599: The model_metadata is built too early and can become
stale after graph rewrites; move the metadata creation so it runs after all ONNX
mutations: after quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.
- Around line 581-588: The FP16 export path should not use torch.autocast during
tracing because you already perform explicit post-export conversion with
convert_float_to_float16; update the autocast logic so the
torch.autocast("cuda") context is only entered when weights_dtype == "bf16" (and
not when weights_dtype == "fp16"), i.e., change the condition that currently
enables autocast for weights_dtype != "fp32" to specifically check for "bf16"
and leave the FP16 path to rely solely on convert_float_to_float16; keep the
convert_float_to_float16 call for FP16 unchanged.
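The dtype gating this prompt asks for reduces to a small decision function; the name and dict shape below are illustrative, not the actual torch_onnx.py code:

```python
def select_fp16_strategy(weights_dtype: str) -> dict:
    """Per the review: torch.autocast during tracing only for bf16;
    the fp16 path relies solely on post-export convert_float_to_float16."""
    return {
        "use_autocast_tracing": weights_dtype == "bf16",
        "post_export_fp16_convert": weights_dtype == "fp16",
    }

fp16_plan = select_fp16_strategy("fp16")
bf16_plan = select_fp16_strategy("bf16")
```

With the current `weights_dtype != "fp32"` condition both plans would enable autocast; the fix narrows it so fp16 export skips autocast entirely.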
---
Nitpick comments:
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 62-71: The current import-time monkey-patch of
_f16_module.remove_unnecessary_cast_node (using
_original_remove_unnecessary_cast_node and
_patched_remove_unnecessary_cast_node) is global and hides all AttributeError
via suppress(AttributeError); instead, scope the workaround only around the call
to convert_float_to_float16(): before calling convert_float_to_float16() save
the original _f16_module.remove_unnecessary_cast_node, replace it with a minimal
wrapper that only catches the specific list/attribute error, call
convert_float_to_float16(), and finally restore the original function in a
try/finally so the patch is temporary and does not swallow unrelated
AttributeErrors or affect the rest of the process.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a86e0e09-2562-47a8-9985-98fa922fd2f0
📥 Commits
Reviewing files that changed from the base of the PR and between 2ebf0a2 and 788313fa4a998c8f18315b736310f700b29a6a21.
📒 Files selected for processing (6)
- CHANGELOG.rst
- modelopt/onnx/autocast/precisionconverter.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/export/nvfp4_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/utils.py
# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
for node in graph.nodes:
    if node.op == "TRT_FP8QuantizeLinear":
        node.op = "QuantizeLinear"
        # Add FP8 zero_point if not present
        if len(node.inputs) == 2:
            # Create FP8 zero point constant
            zp_tensor = onnx.TensorProto()
            zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
            zp_tensor.dims.extend([1])  # 1-element tensor
            zp_tensor.raw_data = b"\x00"  # Zero in FP8
            zp_values = LazyValues(zp_tensor)
            zero_point = gs.Constant(node.name + "_zero_point", zp_values)
            node.inputs.append(zero_point)
        # Add saturate attribute for FP8
        node.attrs["saturate"] = 1
        logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")

# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
for node in graph.nodes:
    if node.op == "TRT_FP8DequantizeLinear":
        node.op = "DequantizeLinear"
        logger.debug(
            f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
        )

graph.cleanup().toposort()
🧩 Analysis chain
🌐 Web query:
Using the official ONNX operator and type documentation, what is the minimum ai.onnx opset that supports FLOAT8E4M3FN with QuantizeLinear and DequantizeLinear, including the saturate attribute on QuantizeLinear?
💡 Result:
Minimum ai.onnx opset: 19.
- QuantizeLinear first supports `tensor(float8e4m3fn)` (via its T2 type constraint) in since_version = 19, and this is also the version that defines the `saturate` attribute (applying to float8 quantization). [1]
- DequantizeLinear first supports `tensor(float8e4m3fn)` (via its T1 type constraint) in since_version = 19. [2]
Sources:
[1] QuantizeLinear (ONNX operator doc, version 19) (onnx.ai)
[2] DequantizeLinear (ONNX operator doc, version 19) (onnx.ai)
Citations:
- 1: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html?utm_source=openai
- 2: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html
Add opset >= 19 validation before FP8 Q/DQ conversion in FP8QuantExporter.post_process().
The code converts TRT custom ops to native QuantizeLinear/DequantizeLinear with FLOAT8E4M3FN and the saturate attribute, but does not verify that the model's opset is >= 19 (the minimum required for these operators). When callers invoke get_onnx_bytes_and_metadata() with onnx_opset < 19 on a FP8-quantized model, the post-processor will silently generate an invalid ONNX model instead of upgrading the opset or raising an error.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, In
FP8QuantExporter.post_process(), before converting
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to native
QuantizeLinear/DequantizeLinear using FLOAT8E4M3FN and the saturate attribute,
validate the model opset version is >= 19; locate the method
FP8QuantExporter.post_process and check the graph/model opset (opset_import or
graph.model.opset_import) and if opset < 19 either raise a clear exception
(e.g., ValueError) telling callers to use onnx_opset >= 19 or programmatically
upgrade the model opset to 19 before performing the conversions (and then
proceed with the existing replacement logic for TRT_FP8QuantizeLinear and
TRT_FP8DequantizeLinear).
onnx_opt_graph = quantize_weights(model, onnx_opt_graph)

if dq_only:
    onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

try:
    # TODO: Single-precision torch model assumed
    param_dtype = next(model.parameters()).dtype
except StopIteration:
    param_dtype = torch.float32
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
    if is_int4_quantized(model) or is_mxfp8_quantized(model):
        assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
        onnx_opt_graph = convert_float_to_float16(
            onnx_opt_graph,
            keep_io_types=False,
            disable_shape_infer=True,
            check_fp16_ready=False,
        )
    else:
        onnx_opt_graph = convert_to_f16(
            onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
        )
        if weights_dtype == "fp16":
            onnx_opt_graph = convert_float_to_float16(
                onnx_opt_graph,
                keep_io_types=False,
                disable_shape_infer=True,
                check_fp16_ready=False,
                op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
            )
            # Change FP32 cast nodes feeding into Concat/Add to FP16
            onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])

# TensorRT expects all scales to be positive
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

# TODO: Remove manual ir_version change once ORT supports ir_version 11
# Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version
onnx_opt_graph.ir_version = 10
Rebuild model_metadata after the final ONNX rewrites.
Lines 572-574 capture metadata before quantize_weights(), convert_float_to_float16(), change_casts_to_fp16(), and remove_redundant_casts(). Those passes add/remove nodes and can rewrite I/O tensors, so the returned metadata can drift from the serialized model bytes.
🛠️ Proposed fix
Move the metadata creation block below the last graph mutation.
- model_metadata = create_model_metadata(
- tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
- )Then re-add it after onnx_opt_graph.ir_version = 10:
+ model_metadata = create_model_metadata(
+ tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 576 - 599, The
model_metadata is built too early and can become stale after graph rewrites;
move the metadata creation so it runs after all ONNX mutations: after
quantize_weights(), qdq_to_dq(), convert_float_to_float16(),
change_casts_to_fp16(), remove_redundant_casts(), and
replace_zero_scale_with_smallest_nonzero(), and ensure you rebuild it after
setting onnx_opt_graph.ir_version = 10 so the returned metadata matches the
final serialized graph bytes.
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/onnx/utils.py (1)
1507-1511: ⚠️ Potential issue | 🟠 Major: Guard FP32→FP16 cast rewrites by input dtype.
At line 1507, this rewrites any `Cast(to=FLOAT)` feeding target ops, including casts from non-FP32 sources. That can silently change behavior by removing intentional upcasts.
💡 Proposed fix
 for node in model.graph.node:
     if node.op_type != "Cast":
         continue
@@
     if not feeds_target:
         continue
+    cast_input_type = _get_tensor_type_by_name(model, node.input[0])
+    if cast_input_type != onnx.TensorProto.FLOAT:
+        continue
+
     # Check if Cast is to FP32, and change to FP16
     for attr in node.attribute:
         if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
             attr.i = onnx.TensorProto.FLOAT16
             break
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
921-923: Prefer public `onnx_utils` APIs over underscored helpers across modules. Using `_get_tensor_type_by_name`, `_bypass_cast_node`, and `_is_same_type_cast` from another module couples this class to private implementation details. Exposing public wrappers in `modelopt/onnx/utils.py` would make this boundary safer.
Also applies to: 942-942, 1097-1102
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/autocast/precisionconverter.py` around lines 921 - 923, The code currently calls private helpers onnx_utils._get_tensor_type_by_name, onnx_utils._bypass_cast_node, and onnx_utils._is_same_type_cast from precisionconverter.py; replace these calls with public wrapper APIs (e.g., get_tensor_type_by_name, bypass_cast_node, is_same_type_cast) exported from modelopt/onnx/utils.py and update precisionconverter.py to call those public names (also update the other locations that use the underscored helpers around lines referenced). Add the small public wrapper implementations in modelopt/onnx/utils.py that delegate to the existing private functions so other modules use the stable public API and then run tests to ensure behavior is unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/utils.py`:
- Around line 1296-1322: The current _is_sequential_cast only compares the two
Cast target types; modify it to also fetch the data type of the original source
feeding the first Cast (e.g., inspect the producer of node.input[0] or its
ValueInfo/initializer) and verify that this source type equals the second cast's
target type (the value returned by get_cast_to_type(next_node)) before returning
True; this extra check ensures that when _bypass_cast_node rewires the graph the
source type is compatible with the second Cast. Use get_consumer_nodes,
get_cast_to_type and the node.input[0] producer lookup to locate and compare
types.
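The extra source-type check this prompt describes could be folded into the helper roughly as follows. The `PRECISION_ORDER` mirrors the real function, while the graph lookup is reduced to plain dtype arguments; this is a conservative sketch of the requested behavior, not the final implementation:

```python
# Numeric values of the corresponding onnx.TensorProto enum members.
FLOAT, FLOAT16, DOUBLE, BFLOAT16 = 1, 10, 11, 16
PRECISION_ORDER = [DOUBLE, FLOAT, FLOAT16, BFLOAT16]

def is_sequential_cast(first_to: int, second_to: int, source_type: int) -> bool:
    """First Cast of a Cast->Cast pair is removable only if it does not lose
    precision relative to the second Cast AND, per the review's new guard,
    the original source type equals the second cast's target type, so that
    bypassing the first cast cannot change the result type."""
    if first_to not in PRECISION_ORDER or second_to not in PRECISION_ORDER:
        return False
    no_precision_loss = (
        PRECISION_ORDER.index(first_to) <= PRECISION_ORDER.index(second_to)
    )
    return no_precision_loss and source_type == second_to

# fp16 -> Cast(to=FLOAT) -> Cast(to=FLOAT16): safe to drop the first cast
ok = is_sequential_cast(FLOAT, FLOAT16, FLOAT16)
# fp32 source: bypassing would hand the second cast a different source type
bad = is_sequential_cast(FLOAT, FLOAT16, FLOAT)
```

The second case is rejected even though it might be safe in practice; that conservatism is what the prompt asks for, since `_bypass_cast_node` rewires consumers directly to the source tensor.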
---
Nitpick comments:
In `@modelopt/onnx/autocast/precisionconverter.py`:
- Around line 921-923: The code currently calls private helpers
onnx_utils._get_tensor_type_by_name, onnx_utils._bypass_cast_node, and
onnx_utils._is_same_type_cast from precisionconverter.py; replace these calls
with public wrapper APIs (e.g., get_tensor_type_by_name, bypass_cast_node,
is_same_type_cast) exported from modelopt/onnx/utils.py and update
precisionconverter.py to call those public names (also update the other
locations that use the underscored helpers around lines referenced). Add the
small public wrapper implementations in modelopt/onnx/utils.py that delegate to
the existing private functions so other modules use the stable public API and
then run tests to ensure behavior is unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 188d9d44-97fd-4011-abb0-37b5f6bdbf27
📥 Commits
Reviewing files that changed from the base of the PR and between 788313fa4a998c8f18315b736310f700b29a6a21 and 40ce80f69a78b1880e57e7b3f685ec8301a14095.
📒 Files selected for processing (5)
- modelopt/onnx/autocast/graphsanitizer.py
- modelopt/onnx/autocast/precisionconverter.py
- modelopt/onnx/autocast/utils.py
- modelopt/onnx/utils.py
- tests/unit/onnx/autocast/test_precisionconverter.py
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
    assert node.op_type == "Cast"
    output_type = get_cast_to_type(node)

    # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
    # Cast to low precision -> cast to high precision affects precision and should not be removed
    precision_order = [
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.BFLOAT16,
    ]
    consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"]

    # If the first cast has additional consumers, we should not remove it
    if len(consumers) != 1:
        return False

    next_node = consumers[0]
    first_cast_type = output_type
    second_cast_type = get_cast_to_type(next_node)

    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )
There was a problem hiding this comment.
🧩 Analysis chain

🏁 Scripts executed:

```shell
# First, locate and read the function
cat -n modelopt/onnx/utils.py | sed -n '1290,1330p'

# Also check if the helper function _get_tensor_type_by_name exists (used in proposed fix)
rg "_get_tensor_type_by_name" modelopt/onnx/utils.py

# Search for get_cast_to_type and get_consumer_nodes to understand them
rg "def get_cast_to_type|def get_consumer_nodes" modelopt/onnx/utils.py

# Check if there are tests for this function
fd "test.*\.py" | xargs rg "_is_sequential_cast" 2>/dev/null || echo "No tests found"

# Find where _is_sequential_cast is called
rg "_is_sequential_cast" modelopt/onnx/utils.py -B 2 -A 2

# Look at the context around _bypass_cast_node to understand what happens when removal occurs
cat -n modelopt/onnx/utils.py | sed -n '1325,1360p'

# Search for comments or documentation about cast removal logic
rg "sequential.*cast|cast.*removal|bypass.*cast" modelopt/onnx/utils.py -i -B 2 -A 2

# Look for any test files that might test cast optimization
fd "test.*\.py" modelopt/onnx/ | head -10

# Get more context about _is_same_type_cast to understand the overall pattern
cat -n modelopt/onnx/utils.py | sed -n '1285,1295p'

# Look for the full remove_redundant_casts function to understand the overall flow
rg "def remove_redundant_casts" modelopt/onnx/utils.py -A 30

# Check if there's any documentation or related test about safe cast removal
fd "*.py" modelopt/onnx/ | xargs grep -l "cast\|Cast" | head -5

# Look at how _get_tensor_type_by_name works and what it can return
rg "def _get_tensor_type_by_name" modelopt/onnx/utils.py -A 20

# Check for any tests or usage examples that show what types node.input can have
fd ".py" modelopt/onnx/ --type f | xargs grep -l "remove_redundant_casts\|_is_sequential_cast" 2>/dev/null

# Look at the overall structure of the optimization and find if there are any tests
fd ".py" --type f modelopt/onnx/ | xargs grep -l "test.*cast\|cast.*test" 2>/dev/null | head -5

# Check the precisionconverter file to see if it has any tests or usage
cat -n modelopt/onnx/autocast/precisionconverter.py | head -50

# See if there are any examples or documentation about this function
rg "sequential.*cast" modelopt/ -i --type py | head -20

# Check if the source type is ever checked anywhere in the function
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p' | grep -E "input|src|source"

# Look at the broader context to understand how the function comment relates to the logic
cat -n modelopt/onnx/utils.py | sed -n '1296,1323p'

# Understand what happens when we remove the first cast by looking at bypass logic
cat -n modelopt/onnx/utils.py | sed -n '1325,1355p'
```
Sequential-cast removal must validate the original source type against the final cast target.

The function only compares the two cast target types (lines 1318-1322) but doesn't verify that removing the first cast preserves the input type to the second cast. When `_bypass_cast_node` rewires the graph, it directly connects the original source to the second cast. If the source type differs from what the second cast expects, this changes behavior.

For example:
- `FLOAT16 → cast(FLOAT) → cast(FLOAT16)` passes the current check (FLOAT ≤ FLOAT16)
- After removal: `FLOAT16 → cast(FLOAT16)`; but cast2 was designed for FLOAT input, causing incorrect behavior

Add a check to ensure the source type matches the second cast's target type before removal.
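The proposed ordering check can be sketched in plain Python. This is a minimal illustration, not the ModelOpt implementation: the dtype constants mirror `onnx.TensorProto` enum values so the sketch stays dependency-free, and `is_removable_sequential_cast` is a hypothetical helper name.

```python
# Integer stand-ins mirroring onnx.TensorProto dtype codes (kept inline so the
# sketch does not require the onnx package).
FLOAT, FLOAT16, DOUBLE, BFLOAT16 = 1, 10, 11, 16

# Highest precision first, matching precision_order in the reviewed function.
PRECISION_ORDER = [DOUBLE, FLOAT, FLOAT16, BFLOAT16]

def is_removable_sequential_cast(source_type: int, first_cast_type: int, second_cast_type: int) -> bool:
    """Approve removal of the first cast only when precision never increases
    along source -> cast1 -> cast2 (the three-way ordering proposed above)."""
    try:
        s = PRECISION_ORDER.index(source_type)
        f = PRECISION_ORDER.index(first_cast_type)
        t = PRECISION_ORDER.index(second_cast_type)
    except ValueError:
        # Unknown type: be conservative and keep the cast.
        return False
    return s <= f <= t
```

With this rule, `FLOAT16 → cast(FLOAT) → cast(FLOAT16)` is rejected because the source already sits below the first cast target, while `DOUBLE → cast(FLOAT) → cast(FLOAT16)` remains removable.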
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/utils.py` around lines 1296 - 1322, The current
_is_sequential_cast only compares the two Cast target types; modify it to also
fetch the data type of the original source feeding the first Cast (e.g., inspect
the producer of node.input[0] or its ValueInfo/initializer) and verify that this
source type equals the second cast's target type (the value returned by
get_cast_to_type(next_node)) before returning True; this extra check ensures
that when _bypass_cast_node rewires the graph the source type is compatible with
the second Cast. Use get_consumer_nodes, get_cast_to_type and the node.input[0]
producer lookup to locate and compare types.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Force-pushed from 40ce80f to 05c33b2.
Actionable comments posted: 4
♻️ Duplicate comments (3)
modelopt/torch/_deploy/utils/torch_onnx.py (2)
`576-576`: ⚠️ Potential issue | 🟠 Major

Rebuild `model_metadata` after the last graph rewrite.

Starting with `quantize_weights()`, this function mutates node names, tensor dtypes, and Q/DQ structure after `model_metadata` has already been captured above. The returned metadata can therefore describe a different graph than the bytes written to disk.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` at line 576, model_metadata is captured before quantize_weights mutates the ONNX graph (node names, tensor dtypes, Q/DQ structure), so the saved metadata can be out of sync with onnx_opt_graph; after the final graph rewrite (the call to quantize_weights that returns onnx_opt_graph) re-run the metadata extraction routine (the same function that produced model_metadata earlier) to rebuild model_metadata from the mutated onnx_opt_graph so the metadata matches the bytes written to disk; update any downstream uses to reference the new model_metadata variable produced after quantize_weights.
`581-592`: ⚠️ Potential issue | 🟠 Major

The FP16 path is still "autocast + `convert_float_to_float16()`".

With this new rewrite block, `weights_dtype="fp16"` still traces under `torch.autocast("cuda")` earlier in the function, so FP16 export is not actually using `convert_float_to_float16()` instead of autocast. That makes the exported graph depend on both mechanisms and undermines the stated pipeline change.

🛠️ Suggested earlier-function change

```diff
-    use_torch_autocast = not (
-        is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
-    )
+    use_torch_autocast = weights_dtype == "bf16" and not (
+        is_fp4_quantized(model) or is_mxfp8_quantized(model)
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 581 - 592, The current FP16 branch still runs under torch.autocast earlier, so the export uses both autocast and convert_float_to_float16; modify the export flow so when weights_dtype == "fp16" you do NOT run the model export under torch.autocast("cuda") (or skip the autocast context) so the graph is produced in full-precision then transformed only by convert_float_to_float16 and change_casts_to_fp16; locate the autocast usage earlier in this module (torch.autocast or the export context manager) and add a conditional to bypass it when weights_dtype == "fp16", ensuring convert_float_to_float16/remove_redundant_casts are the sole FP16 transformations.

modelopt/onnx/export/fp8_exporter.py (1)
`121-147`: ⚠️ Potential issue | 🟠 Major

Reject native FP8 Q/DQ export below opset 19.

This conversion emits native `QuantizeLinear`/`DequantizeLinear` with `FLOAT8E4M3FN` and `saturate`, but it still doesn't guard against `onnx_opset < 19`. Exporting FP8 with a lower opset will silently produce an invalid model instead of failing early or upgrading the opset.

🛠️ Suggested guard

```diff
 def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    opset = next(
+        (op.version for op in onnx_model.opset_import if op.domain in ("", "ai.onnx")),
+        0,
+    )
+    if opset < 19:
+        raise ValueError("Native FP8 ONNX Q/DQ requires ai.onnx opset >= 19.")
+
     logger.info("Post-processing FP8 quantized model")
     graph = gs.import_onnx(onnx_model)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 121 - 147, The conversion loop for TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx opset < 19; detect the model/export opset (e.g., an existing variable or by examining graph.opset or a passed opset parameter) before performing the conversions in the loops that change node.op to "QuantizeLinear"/"DequantizeLinear" and set FLOAT8E4M3FN/saturate, and if the opset is less than 19 either raise a clear exception or upgrade the opset to >= 19 before modifying nodes (apply this check where you manipulate node.op and node.attrs in the FP8 exporter function that iterates graph.nodes).
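The suggested guard reduces to scanning the model's `opset_import` entries for the default/`ai.onnx` domain. A dependency-free sketch, with `OpsetId` standing in for ONNX's `OperatorSetIdProto` and `check_fp8_opset` as a hypothetical helper name:

```python
from dataclasses import dataclass

@dataclass
class OpsetId:
    """Stand-in for onnx's OperatorSetIdProto (domain + version)."""
    domain: str
    version: int

def check_fp8_opset(opset_imports, minimum: int = 19) -> int:
    """Return the ai.onnx opset version, raising early when it is too low
    for native FP8 QuantizeLinear/DequantizeLinear."""
    opset = next(
        (op.version for op in opset_imports if op.domain in ("", "ai.onnx")),
        0,  # no ai.onnx entry found: treat as unsupported
    )
    if opset < minimum:
        raise ValueError(
            f"Native FP8 ONNX Q/DQ requires ai.onnx opset >= {minimum}, got {opset}."
        )
    return opset
```

Calling this at the top of `post_process` would turn a silently invalid FP8 export into an immediate, explicit failure.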
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 127-134: The injected zero-point Constant currently uses node.name
which may be empty and cause duplicate tensor names; change the naming for the
Constant created from zp_tensor/zp_values/zero_point to a guaranteed-unique
string (e.g., combine node.name when present with a unique suffix such as a
uuid4 or the node's memory id or an incrementing counter, or use an ONNX/graph
helper that returns a unique name) so each FP8 zero-point Constant has a
distinct tensor name even for unnamed TRT FP8 Q nodes.
In `@modelopt/onnx/utils.py`:
- Around line 1262-1276: The helper _get_tensor_type_by_name must also handle
producer-only tensors emitted by nodes (e.g., a Constant node that produces a
tensor but has no value_info entry); modify _get_tensor_type_by_name to iterate
model.graph.node and when a node.output matches tensor_name, if node.op_type ==
"Constant" extract the TensorProto from the node attribute (attribute named
"value") and return its data_type (or elem_type equivalent), otherwise
skip/continue so non-materialized producer-only tensors do not cause an
exception and allow remove_redundant_casts() to fold Constant->Cast patterns.
- Around line 1403-1417: The two cast-removal checks can both match the same
node causing duplicate removal; make them mutually exclusive by ensuring once a
node is handled by _is_sequential_cast(onnx_model, node) (where you call
_bypass_cast_node and append to nodes_to_remove) you skip the subsequent
_is_foldable_constant_cast_pattern check (e.g., use an elif or continue) so you
only call _bypass_cast_node, _convert_constant_values and append to
nodes_to_remove once; update the block containing _is_sequential_cast,
_is_foldable_constant_cast_pattern, _bypass_cast_node, _convert_constant_values,
get_producer_nodes, nodes_to_remove, and logger.debug accordingly.
- Around line 1499-1511: The current logic uses any(...) on tensor_to_consumers
to set Casts to FP16 even if only one consumer is in target_op_types, which
incorrectly changes shared Cast outputs; modify the check so the Cast is
retargeted only when the entire fanout is eligible (replace the any(...) test
with an all(...) test, and treat empty consumer lists as ineligible/skip), then
keep the existing loop over node.attribute that looks for attr.name == "to" and
change attr.i from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 only when
that all-consumers condition holds.
---
Duplicate comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 121-147: The conversion loop for
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear must guard against onnx opset <
19; detect the model/export opset (e.g., an existing variable or by examining
graph.opset or a passed opset parameter) before performing the conversions in
the loops that change node.op to "QuantizeLinear"/"DequantizeLinear" and set
FLOAT8E4M3FN/saturate, and if the opset is less than 19 either raise a clear
exception or upgrade the opset to >= 19 before modifying nodes (apply this check
where you manipulate node.op and node.attrs in the FP8 exporter function that
iterates graph.nodes).
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Line 576: model_metadata is captured before quantize_weights mutates the ONNX
graph (node names, tensor dtypes, Q/DQ structure), so the saved metadata can be
out of sync with onnx_opt_graph; after the final graph rewrite (the call to
quantize_weights that returns onnx_opt_graph) re-run the metadata extraction
routine (the same function that produced model_metadata earlier) to rebuild
model_metadata from the mutated onnx_opt_graph so the metadata matches the bytes
written to disk; update any downstream uses to reference the new model_metadata
variable produced after quantize_weights.
- Around line 581-592: The current FP16 branch still runs under torch.autocast
earlier, so the export uses both autocast and convert_float_to_float16; modify
the export flow so when weights_dtype == "fp16" you do NOT run the model export
under torch.autocast("cuda") (or skip the autocast context) so the graph is
produced in full-precision then transformed only by convert_float_to_float16 and
change_casts_to_fp16; locate the autocast usage earlier in this module
(torch.autocast or the export context manager) and add a conditional to bypass
it when weights_dtype == "fp16", ensuring
convert_float_to_float16/remove_redundant_casts are the sole FP16
transformations.
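The fp16 fix described above amounts to choosing the export context from `weights_dtype`. A hedged sketch under the assumption (from the suggested diff) that only the bf16 path should keep autocast; the function name and dtype strings are illustrative:

```python
from contextlib import nullcontext

def export_autocast_context(weights_dtype: str):
    """Return the context manager to trace the export under.

    Hypothetical helper: per the review suggestion, fp16 and fp32 exports
    bypass autocast entirely so convert_float_to_float16 (plus
    remove_redundant_casts) is the sole FP16 transformation.
    """
    if weights_dtype != "bf16":
        return nullcontext()
    import torch  # assumed available in the real export path
    return torch.autocast("cuda")
```

Used as `with export_autocast_context(weights_dtype): torch.onnx.export(...)`, this keeps the fp16 graph in full precision at trace time.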
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7704d15e-7bdf-4b56-b725-8f52b69a07e2
📥 Commits
Reviewing files that changed from the base of the PR and between 40ce80f69a78b1880e57e7b3f685ec8301a14095 and 05c33b2.
📒 Files selected for processing (9)
- CHANGELOG.rst
- modelopt/onnx/autocast/graphsanitizer.py
- modelopt/onnx/autocast/precisionconverter.py
- modelopt/onnx/autocast/utils.py
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/export/nvfp4_exporter.py
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
- tests/unit/onnx/autocast/test_precisionconverter.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/onnx/autocast/test_precisionconverter.py
```python
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1])  # 1-element tensor
zp_tensor.raw_data = b"\x00"  # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
```
Use a guaranteed-unique tensor name for the injected zero point.
node.name is optional in ONNX, so node.name + "_zero_point" can collapse to the same tensor name for multiple unnamed TRT FP8 Q nodes. That can make the exported graph invalid due to duplicate tensor names.
🛠️ Safer naming
```diff
-        zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+        zero_point = gs.Constant(f"{node.outputs[0].name}_zero_point", zp_values)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/export/fp8_exporter.py` around lines 127 - 134, The injected
zero-point Constant currently uses node.name which may be empty and cause
duplicate tensor names; change the naming for the Constant created from
zp_tensor/zp_values/zero_point to a guaranteed-unique string (e.g., combine
node.name when present with a unique suffix such as a uuid4 or the node's memory
id or an incrementing counter, or use an ONNX/graph helper that returns a unique
name) so each FP8 zero-point Constant has a distinct tensor name even for
unnamed TRT FP8 Q nodes.
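One way to satisfy the uniqueness requirement is a counter-based suffix, as the prompt suggests. A minimal sketch with a hypothetical helper name (the real fix could equally use the output tensor name or a uuid):

```python
import itertools

# Module-level counter so every generated name gets a fresh suffix.
_zp_counter = itertools.count()

def unique_zero_point_name(node_name: str) -> str:
    """Build a zero-point tensor name that stays unique even when the
    TRT FP8 node is unnamed (node_name may be an empty string)."""
    base = node_name or "trt_fp8_q"  # fallback prefix for unnamed nodes
    return f"{base}_zero_point_{next(_zp_counter)}"
```

Two unnamed nodes then receive distinct constant names instead of both collapsing to `_zero_point`.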
Review Comments

Thanks for the PR. The core idea of replacing TRT-specific FP8 QDQ nodes with native ONNX ops is solid, and the refactoring to centralize cast utilities in

1. BF16 regression in
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Actionable comments posted: 2
♻️ Duplicate comments (3)
modelopt/onnx/utils.py (3)
`1420-1435`: ⚠️ Potential issue | 🟠 Major

Missing `continue` after `_is_sequential_cast` can cause double processing.

A node matching both `_is_sequential_cast` and `_is_foldable_constant_cast_pattern` (e.g., `Constant → Cast → Cast`) will be processed by both blocks, leading to duplicate entries in `nodes_to_remove` and potential errors during removal.

🛠️ Proposed fix

```diff
         # Find sequential casts that don't change precision
         if _is_sequential_cast(onnx_model, node):
             nodes_to_remove.append(node)
             _bypass_cast_node(onnx_model, node)
             logger.debug(f"Found removable double-cast: {node.name}")
+            continue

         # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
         if _is_foldable_constant_cast_pattern(onnx_model, node):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/utils.py` around lines 1420 - 1435, The sequential-cast branch (_is_sequential_cast) can fall through and be reprocessed by the foldable Constant->Cast branch, causing duplicate entries in nodes_to_remove and double-modification of the same node; after handling a sequential cast (where you call _bypass_cast_node and log), add a control-flow break (e.g., a continue) to skip further checks for that node so it isn't also processed by _is_foldable_constant_cast_pattern, ensuring nodes_to_remove, _bypass_cast_node, _convert_constant_values, get_producer_nodes and the node aren't acted on twice.
`1516-1528`: ⚠️ Potential issue | 🟠 Major

Using `any()` can incorrectly change shared Cast outputs, affecting non-target consumers.

When a Cast node's output feeds multiple consumers and only some are in `target_op_types`, changing the Cast to FP16 affects all consumers, potentially breaking non-target branches.

🛠️ Proposed fix: only change if ALL consumers are target ops

```diff
         # Check if this Cast outputs to a target op type
         cast_output = node.output[0]
         consumers = tensor_to_consumers.get(cast_output, [])
-        feeds_target = any(c.op_type in target_op_types for c in consumers)
+        feeds_target = bool(consumers) and all(c.op_type in target_op_types for c in consumers)
         if not feeds_target:
             continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/utils.py` around lines 1516 - 1528, The current logic uses any() and flips a Cast to FP16 even when some consumers are non-target, which can break those branches; change the feeds_target condition to only be true when there is at least one consumer and ALL consumers are in target_op_types (i.e., replace feeds_target = any(...) with feeds_target = bool(consumers) and all(c.op_type in target_op_types for c in consumers)), then proceed with the existing node.attribute loop that checks attr.name == "to" and attr.i to change FLOAT to FLOAT16.
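The all-consumers condition is easy to isolate. A sketch operating on plain op-type strings; the real code would check `c.op_type` on `NodeProto` consumers, and the helper name is hypothetical:

```python
def all_consumers_eligible(consumer_op_types, target_op_types) -> bool:
    """Retarget a shared Cast output to FP16 only when every consumer is a
    target op; an empty fanout is treated as ineligible."""
    consumers = list(consumer_op_types)
    return bool(consumers) and all(op in target_op_types for op in consumers)
```

This is the `any()` → `bool(...) and all(...)` swap from the proposed fix, expressed as a standalone predicate.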
`1262-1277`: ⚠️ Potential issue | 🟡 Minor

Function does not handle Constant node outputs that lack `value_info` entries.

When a Constant node produces an output tensor that isn't registered in `value_info`, `initializer`, `input`, or `output`, this function raises an exception. This can occur before `_is_foldable_constant_cast_pattern` gets a chance to handle it. Consider checking node outputs for Constant nodes:

🛠️ Proposed fix

```diff
 def _get_tensor_type_by_name(model: onnx.ModelProto, tensor_name: str):
     """Get the tensor element type. Searches value_info, initializers, inputs, and outputs."""
     for vi in model.graph.value_info:
         if vi.name == tensor_name:
             return vi.type.tensor_type.elem_type
     for init in model.graph.initializer:
         if init.name == tensor_name:
             return init.data_type
     for inp in model.graph.input:
         if inp.name == tensor_name:
             return inp.type.tensor_type.elem_type
     for out in model.graph.output:
         if out.name == tensor_name:
             return out.type.tensor_type.elem_type
+    # Check Constant node outputs
+    for node in model.graph.node:
+        if node.op_type == "Constant" and tensor_name in node.output:
+            for attr in node.attribute:
+                if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
+                    return attr.t.data_type
+            break
     raise Exception(f"did not find tensor {tensor_name}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/utils.py` around lines 1262 - 1277, _get_tensor_type_by_name currently misses Constant node outputs that aren't listed in value_info/initializer/input/output; update it to also scan model.graph.node for nodes with op_type == "Constant" whose output list contains tensor_name, then find the node AttributeProto with name "value" (the TensorProto stored on the Constant) and return its data_type (attr.t.data_type). Keep the existing checks first (value_info/initializer/input/output), add the Constant-node check before raising the Exception, and otherwise keep the same error behavior.
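The Constant-node fallback can be illustrated without the onnx package. In this sketch, lightweight `(op_type, outputs, value_dtype)` records stand in for `NodeProto` entries, and the function name is hypothetical:

```python
def tensor_type_from_constant_nodes(nodes, tensor_name):
    """Resolve a tensor's dtype from the Constant node that produces it.

    `nodes` is a list of (op_type, outputs, value_dtype) tuples standing in
    for onnx NodeProto entries; in the real fix, value_dtype comes from the
    Constant node's "value" TensorProto attribute (attr.t.data_type).
    Returns None when no producing Constant is found, instead of raising.
    """
    for op_type, outputs, value_dtype in nodes:
        if op_type == "Constant" and tensor_name in outputs:
            return value_dtype
    return None
```

This mirrors the extra lookup appended before the `raise` in the proposed diff, letting `remove_redundant_casts()` fold `Constant → Cast` patterns instead of failing.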
🧹 Nitpick comments (2)
modelopt/onnx/utils.py (1)
`1289-1294`: Redundant `None` check on list comprehension result.

`input_types` is always a list (from the list comprehension), so `input_types is not None` is always `True`. This check can be simplified.

♻️ Proposed fix

```diff
 def _is_same_type_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
     assert node.op_type == "Cast"
     input_types = [_get_tensor_type_by_name(model, inp) for inp in node.input]
     output_type = get_cast_to_type(node)
-    return all(inp_type == output_type for inp_type in input_types) and input_types is not None
+    return bool(input_types) and all(inp_type == output_type for inp_type in input_types)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/utils.py` around lines 1289 - 1294, The function _is_same_type_cast has a redundant "input_types is not None" check because input_types is always a list from the list comprehension; update the return statement in _is_same_type_cast (referencing variables input_types and output_type) to simply return all(inp_type == output_type for inp_type in input_types), removing the unnecessary None check.modelopt/torch/_deploy/utils/torch_onnx.py (1)
`573-576`: `model_metadata` created before graph mutations may have stale `onnx_node_names`.

Metadata is captured at lines 573-575 before `quantize_weights`, FP16 conversion, and `remove_redundant_casts`. While I/O names and shapes should remain stable, `onnx_node_names` will be stale after `remove_redundant_casts` removes Cast nodes.

If `onnx_node_names` is used for validation or debugging, consider moving metadata creation after line 606 (after all mutations).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 573 - 576, The code creates model_metadata via create_model_metadata(...) before performing graph mutations, which can leave onnx_node_names stale; move the model_metadata = create_model_metadata(tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model) call to after the transformations (after quantize_weights, FP16 conversion/convert_model_to_fp16, and remove_redundant_casts) so that onnx_opt_graph reflects the final node set; ensure you reference the same onnx_opt_graph and model when recreating metadata so onnx_node_names are accurate for validation/debugging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/utils.py`:
- Around line 1296-1323: In _is_sequential_cast, add validation of the original
source tensor type before approving removal: fetch the source type via
_get_tensor_type_by_name(model, node.input[0]) and compute its index in the
existing precision_order; then ensure source_index <= index(first_cast_type) <=
index(second_cast_type) (using get_cast_to_type for first/second casts and
get_consumer_nodes to find the next Cast) and only return True when this
three-way ordering holds; if the source type cannot be resolved or any type not
in precision_order, return False.
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 63-73: Add a detailed comment above the monkey-patch for
_f16_module.remove_unnecessary_cast_node explaining the upstream bug being
worked around: reference the upstream issue/PR URL (or issue number) that
describes the "downstream_node is a list" failure, explicitly state the exact
AttributeError message being suppressed (e.g., "'list' object has no attribute
'some_attr'") and under what conditions it occurs, and document the minimum
onnxconverter_common version (or commit/PR) that fixes it so this workaround can
be removed; include the unique symbols _original_remove_unnecessary_cast_node,
_patched_remove_unnecessary_cast_node, and suppress(AttributeError) in the
comment so readers can quickly locate the patched code.
---
Duplicate comments:
In `@modelopt/onnx/utils.py`:
- Around line 1420-1435: The sequential-cast branch (_is_sequential_cast) can
fall through and be reprocessed by the foldable Constant->Cast branch, causing
duplicate entries in nodes_to_remove and double-modification of the same node;
after handling a sequential cast (where you call _bypass_cast_node and log), add
a control-flow break (e.g., a continue) to skip further checks for that node so
it isn't also processed by _is_foldable_constant_cast_pattern, ensuring
nodes_to_remove, _bypass_cast_node, _convert_constant_values, get_producer_nodes
and the node aren't acted on twice.
- Around line 1516-1528: The current logic uses any() and flips a Cast to FP16
even when some consumers are non-target, which can break those branches; change
the feeds_target condition to only be true when there is at least one consumer
and ALL consumers are in target_op_types (i.e., replace feeds_target = any(...)
with feeds_target = bool(consumers) and all(c.op_type in target_op_types for c
in consumers)), then proceed with the existing node.attribute loop that checks
attr.name == "to" and attr.i to change FLOAT to FLOAT16.
- Around line 1262-1277: _get_tensor_type_by_name currently misses Constant node
outputs that aren't listed in value_info/initializer/input/output; update it to
also scan model.graph.node for nodes with op_type == "Constant" whose output
list contains tensor_name, then find the node AttributeProto with name "value"
(the TensorProto stored on the Constant) and return its data_type
(attr.t.data_type). Keep the existing checks first
(value_info/initializer/input/output), add the Constant-node check before
raising the Exception, and otherwise keep the same error behavior.
---
Nitpick comments:
In `@modelopt/onnx/utils.py`:
- Around line 1289-1294: The function _is_same_type_cast has a redundant
"input_types is not None" check because input_types is always a list from the
list comprehension; update the return statement in _is_same_type_cast
(referencing variables input_types and output_type) to simply return
all(inp_type == output_type for inp_type in input_types), removing the
unnecessary None check.
In `@modelopt/torch/_deploy/utils/torch_onnx.py`:
- Around line 573-576: The code creates model_metadata via
create_model_metadata(...) before performing graph mutations, which can leave
onnx_node_names stale; move the model_metadata =
create_model_metadata(tree_spec_input, tree_spec_output, input_none_names,
onnx_opt_graph, model) call to after the transformations (after
quantize_weights, FP16 conversion/convert_model_to_fp16, and
remove_redundant_casts) so that onnx_opt_graph reflects the final node set;
ensure you reference the same onnx_opt_graph and model when recreating metadata
so onnx_node_names are accurate for validation/debugging.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ba926212-e494-4ed1-a553-8657733f661f
📒 Files selected for processing (2)
- modelopt/onnx/utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py
```python
def _is_sequential_cast(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
    assert node.op_type == "Cast"
    output_type = get_cast_to_type(node)

    # Cast to high precision -> cast to low precision: first cast has no impact and can be safely removed
    # Cast to low precision -> cast to high precision affects precision and should not be removed
    precision_order = [
        onnx.TensorProto.DOUBLE,
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.BFLOAT16,
    ]
    consumers = [n for n in get_consumer_nodes(model, node.output[0]) if n.op_type == "Cast"]

    # If the first cast has additional consumers, we should not remove it
    if len(consumers) != 1:
        return False

    next_node = consumers[0]
    first_cast_type = output_type
    second_cast_type = get_cast_to_type(next_node)

    return (
        first_cast_type in precision_order
        and second_cast_type in precision_order
        and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
    )
```
🧩 Analysis chain

🏁 Scripts executed:

```shell
#!/bin/bash
# Check if there are unit tests covering _is_sequential_cast edge cases
rg -l "_is_sequential_cast|remove_redundant_casts" --glob "*test*.py"

# Search for calls to _is_sequential_cast and understand its usage context
rg "_is_sequential_cast" -B 5 -A 5 modelopt/onnx/utils.py

# Look at how node inputs are accessed in the file
rg "node\.input\[0\]" modelopt/onnx/utils.py | head -20

# Check if there are helper functions to get input nodes or types
rg "def get_input|def.*input" modelopt/onnx/utils.py | head -20

# Look at helper function _get_tensor_type_by_name used in the file
rg "_get_tensor_type_by_name" modelopt/onnx/utils.py -A 10 | head -30

# Check the full context of _is_sequential_cast including surrounding functions
sed -n '1280,1330p' modelopt/onnx/utils.py

# Search for get_producer_nodes to understand how to get source node info
rg "def get_producer_nodes" modelopt/onnx/utils.py -A 5

# Get the exact signature and full implementation of _is_sequential_cast to confirm
sed -n '1296,1323p' modelopt/onnx/utils.py | cat -n

# Check if there are any recent commits or comments about source type validation
git log --oneline --all -S "source_type" -- modelopt/onnx/utils.py | head -5
```
Add source type validation to prevent incorrect cast optimization.
The function only compares the first and second cast output types but ignores the original source type. This creates a semantic bug where incorrect optimizations can occur.
Example: If source is BFLOAT16 → Cast to DOUBLE → Cast to FLOAT16, the function currently approves removal of the first cast (since DOUBLE index ≤ FLOAT16 index), resulting in a direct BFLOAT16 → FLOAT16 cast. This differs from the original two-cast path through higher precision.
Fix: Retrieve the source type using _get_tensor_type_by_name(model, node.input[0]) and validate that source_type_index <= first_cast_type_index <= second_cast_type_index before approving removal.
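The suggested three-way ordering check can be sketched as follows. This is a minimal sketch with plain strings standing in for the ONNX `TensorProto` type enums; `PRECISION_ORDER` and `is_safe_sequential_cast` are illustrative names, not the module's actual list or helper:

```python
# Types ordered from highest to lowest precision; illustrative only.
PRECISION_ORDER = ["DOUBLE", "FLOAT", "FLOAT16", "BFLOAT16"]


def is_safe_sequential_cast(source_type: str, first_cast_type: str, second_cast_type: str) -> bool:
    """Approve removing the first Cast only if precision never increases
    along source -> first cast -> second cast."""
    try:
        s = PRECISION_ORDER.index(source_type)
        f = PRECISION_ORDER.index(first_cast_type)
        t = PRECISION_ORDER.index(second_cast_type)
    except ValueError:
        # Any type outside the known ordering: be conservative, keep both casts.
        return False
    return s <= f <= t


# The review's counterexample: BFLOAT16 -> DOUBLE -> FLOAT16 must NOT be
# collapsed, because the source is lower precision than the intermediate cast.
assert not is_safe_sequential_cast("BFLOAT16", "DOUBLE", "FLOAT16")
# FLOAT -> FLOAT16 -> BFLOAT16 only ever narrows, so removal is safe.
assert is_safe_sequential_cast("FLOAT", "FLOAT16", "BFLOAT16")
```

In the real helper, `source_type` would come from `_get_tensor_type_by_name(model, node.input[0])` and the two cast types from `get_cast_to_type`, as the finding describes.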
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `modelopt/onnx/utils.py` around lines 1296 - 1323, in `_is_sequential_cast`, add
validation of the original source tensor type before approving removal: fetch
the source type via _get_tensor_type_by_name(model, node.input[0]) and compute
its index in the existing precision_order; then ensure source_index <=
index(first_cast_type) <= index(second_cast_type) (using get_cast_to_type for
first/second casts and get_consumer_nodes to find the next Cast) and only return
True when this three-way ordering holds; if the source type cannot be resolved
or any type not in precision_order, return False.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Force-pushed from 4dda7bf to 9e3a35a (Compare)
@cjluo-nv,
cjluo-nv left a comment
Review
The refactoring direction (moving graph utilities to onnx/utils.py) is sound, the _build_tensor_type_map O(1) optimization is nice, and accuracy results show negligible regression. A few correctness issues to address before approving:
1. Bug: Missing continue in remove_redundant_casts() (onnx/utils.py)
The _is_sequential_cast and _is_foldable_constant_cast_pattern checks are not mutually exclusive. A Constant -> Cast -> Cast pattern satisfies both. Without a continue after the sequential-cast branch, _bypass_cast_node gets called twice on the same node and it gets appended to nodes_to_remove twice. The second bypass operates on already-modified graph connections, which could corrupt the graph.
```python
# After this block, add `continue`:
if _is_sequential_cast(onnx_model, node):
    nodes_to_remove.append(node)
    _bypass_cast_node(onnx_model, node)
    logger.debug(f"Found removable double-cast: {node.name}")
    continue  # <-- missing
```

2. _get_tensor_type_by_name() can throw on Constant-produced tensors
This helper only searches value_info, initializers, inputs, and outputs. But Constant node outputs are often not materialized in value_info. When _is_same_type_cast calls this on a Cast fed by a Constant, it will raise Exception("did not find tensor ..."). Consider also scanning Constant node attributes in _build_tensor_type_map, or handling the exception gracefully in _is_same_type_cast.
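The suggested fix, scanning `Constant` node attributes when building the type map, can be sketched as below. The `Tensor`/`Attr`/`Node` dataclasses are minimal stand-ins for the real onnx protobufs, and the elem-type integers mirror `onnx.TensorProto` values (1 = FLOAT); the function name is hypothetical:

```python
from dataclasses import dataclass, field


@dataclass
class Tensor:
    data_type: int  # mirrors onnx.TensorProto elem types; 1 = FLOAT


@dataclass
class Attr:
    name: str
    t: Tensor


@dataclass
class Node:
    op_type: str
    output: list[str]
    attribute: list[Attr] = field(default_factory=list)


def add_constant_outputs(type_map: dict[str, int], nodes: list[Node]) -> None:
    """Register Constant node outputs in the tensor-name -> dtype map, since
    they are often absent from value_info."""
    for node in nodes:
        if node.op_type != "Constant":
            continue
        for attr in node.attribute:
            if attr.name == "value":
                type_map[node.output[0]] = attr.t.data_type


tmap: dict[str, int] = {}
add_constant_outputs(tmap, [Node("Constant", ["c0"], [Attr("value", Tensor(1))])])
assert tmap["c0"] == 1  # the Constant's output type is now resolvable
```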
3. change_casts_to_fp16() is overly broad
- Uses `any()` — if a Cast output feeds both a `Concat` and a non-target op, the Cast is flipped to FP16 for all consumers, potentially breaking the non-target branch. Consider using `all()` or only retargeting when the entire fanout is eligible.
- Doesn't verify source type — a Cast from FP64→FP32 would get incorrectly changed to FP64→FP16.
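The fanout concern can be sketched with an `all()` check: retarget a Cast only when every consumer is an eligible op. `Node` is a minimal stand-in for an ONNX node, and `TARGET_OPS` is an illustrative set, not the real one:

```python
from dataclasses import dataclass


@dataclass
class Node:
    op_type: str


TARGET_OPS = {"Concat", "Add"}  # illustrative eligible-op set


def should_retarget_cast(consumers: list[Node]) -> bool:
    # all() rather than any(): a single non-target consumer keeps the Cast
    # as-is, so the non-target branch still sees the original precision.
    return bool(consumers) and all(n.op_type in TARGET_OPS for n in consumers)


assert should_retarget_cast([Node("Concat"), Node("Add")])
# A mixed fanout (one non-target consumer) must leave the Cast untouched.
assert not should_retarget_cast([Node("Concat"), Node("MatMul")])
```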
4. quantize_weights() is now unconditional (torch_onnx.py)
Previously gated by is_int4_quantized or is_fp4_quantized or is_mxfp8_quantized. Does quantize_weights() no-op safely for non-quantized models?
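If `quantize_weights()` does not no-op safely, restoring the gate is straightforward. This is a hedged sketch of such a guard; the flag names follow the ones quoted above, while `maybe_quantize_weights` and its signature are hypothetical:

```python
def maybe_quantize_weights(model, is_int4, is_fp4, is_mxfp8, quantize_fn):
    """Run weight quantization only when one of the quantized-format flags
    is set; non-quantized models pass through untouched."""
    if not (is_int4 or is_fp4 or is_mxfp8):
        return model
    return quantize_fn(model)


# A non-quantized model is returned unchanged; a flagged one is transformed.
assert maybe_quantize_weights("m", False, False, False, lambda m: m + "_q") == "m"
assert maybe_quantize_weights("m", True, False, False, lambda m: m + "_q") == "m_q"
```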
5. Minor: Zero-point tensor naming (fp8_exporter.py)
node.name + "_zero_point" could produce duplicate names if multiple TRT_FP8QuantizeLinear nodes have empty names. Using node.output[0] as the name root would be safer.
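The naming nit can be sketched by deriving the constant's name from the node's output tensor, which the ONNX spec requires to be unique within a graph, with a counter fallback for safety. `zero_point_name` is a hypothetical helper, not the exporter's actual code:

```python
import itertools

_counter = itertools.count()


def zero_point_name(node_name: str, output_name: str) -> str:
    """Prefer the (unique) output tensor name over the (possibly empty)
    node name when rooting the zero-point constant's name."""
    root = output_name or node_name or f"qdq_{next(_counter)}"
    return root + "_zero_point"


# Two nodes with empty names but distinct outputs no longer collide.
assert zero_point_name("", "q_out_0") != zero_point_name("", "q_out_1")
```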
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
cjluo-nv left a comment
LGTM — the latest commit addresses the key review items (missing continue, Constant tensor type lookup, change_casts_to_fp16 scoping, and quantize_weights no-op guard).
Two minor nits for follow-up:
- fp8_exporter.py: `node.name + "_zero_point"` could produce duplicate names if nodes have empty names — consider using `node.output[0]` as the name root instead.
- torch_onnx.py: bare `print("No quantization exporters found...")` — consider `logger.info()` for consistency.
Neither is blocking.
What does this PR do?
Type of change:
New feature
Overview:
Testing
Results:
Before replacement:
After replacement:
Before your PR is "Ready for review"
Summary by CodeRabbit
New Features
Improvements
Tests
Changelog